{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic Operations on Bayesian Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows examples of some basic operations that can be performed on a Bayesian Network. We use the Protein Signalling network from the bnlearn repository as the example model: https://www.bnlearn.com/bnrepository/discrete-medium.html#sachs\n", "\n", "\n", "The `BayesianNetwork` class in pgmpy inherits the `networkx.DiGraph` class, hence all the methods defined for `networkx.DiGraph` should also work for `BayesianNetwork`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pprint\n", "from IPython.display import Image\n", "import networkx as nx\n", "from pgmpy.factors.discrete import TabularCPD\n", "\n", "# Load the sachs model. \n", "# For other ways to define a model, please refer: https://pgmpy.org/examples/Creating%20a%20Discrete%20Bayesian%20Network.html\n", "from pgmpy.utils import get_example_model\n", "sachs_model = get_example_model('sachs')\n", "\n", "# Visualize the model\n", "viz = sachs_model.to_graphviz()\n", "viz.draw('sachs.png', prog='dot')\n", "Image('sachs.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Attributes of the Model Structure" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Nodes: ['Akt', 'Erk', 'Jnk', 'Mek', 'P38', 'PIP2', 'PIP3', 'PKA', 'PKC', 'Plcg', 'Raf'] \n", "\n", "Edges: [('Erk', 'Akt'), ('Mek', 'Erk'), ('PIP3', 'PIP2'), ('PKA', 'Akt'), ('PKA', 'Erk'), ('PKA', 'Jnk'), ('PKA', 'Mek'), ('PKA', 'P38'), ('PKA', 'Raf'), ('PKC', 'Jnk'), ('PKC', 'Mek'), ('PKC', 'P38'), ('PKC', 'PKA'), ('PKC', 'Raf'), ('Plcg', 'PIP2'), ('Plcg', 'PIP3'), ('Raf', 'Mek')] \n", "\n", "Parents of 
Akt: ['Erk', 'PKA'] \n", "\n", "Children of PKA: ['Akt', 'Erk', 'Jnk', 'Mek', 'P38', 'Raf'] \n", "\n", "Leaf nodes in the model: ['Akt', 'Jnk', 'P38', 'PIP2'] \n", "\n", "Root nodes in the model: ['PKC', 'Plcg'] \n", "\n" ] } ], "source": [ "# Get all the nodes/random variables in the model\n", "all_nodes = sachs_model.nodes()\n", "print(f\"Nodes: {all_nodes} \\n\")\n", "\n", "# Get all the edges in the model.\n", "all_edges = sachs_model.edges()\n", "print(f\"Edges: {all_edges} \\n\")\n", "\n", "# Get all the CPDs.\n", "all_cpds = sachs_model.get_cpds()\n", "\n", "# Get parents of a specific node\n", "akt_parents = sachs_model.get_parents('Akt')\n", "print(f\"Parents of Akt: {akt_parents} \\n\")\n", "\n", "# Get children of a specific node\n", "pka_children = sachs_model.get_children('PKA')\n", "print(f\"Children of PKA: {pka_children} \\n\")\n", "\n", "# Get all the leaf nodes of the model\n", "leaves = sachs_model.get_leaves()\n", "print(f\"Leaf nodes in the model: {leaves} \\n\")\n", "\n", "# Get the root nodes of the model\n", "roots = sachs_model.get_roots()\n", "print(f\"Root nodes in the model: {roots} \\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modifying the Model Structure" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Adding nodes to the model.\n", "sachs_model.add_node('new_node')\n", "sachs_model.add_nodes_from(['new_node1', 'new_node2'])\n", "\n", "# Adding edges to the model.\n", "sachs_model.add_edge('Akt', 'new_node')\n", "sachs_model.add_edges_from([('Akt', 'new_node1'), ('Akt', 'new_node2')])\n", "\n", "# Removing edges from the model.\n", "sachs_model.remove_edge('Akt', 'new_node')\n", "sachs_model.remove_edges_from([('Akt', 'new_node1'), ('Akt', 'new_node2')])\n", "\n", "# Removing nodes from the model\n", "sachs_model.remove_node('new_node')\n", "sachs_model.remove_nodes_from(['new_node1', 'new_node2'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": 
{}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# At any point, check_model can be called to check if the specified model is correct.\n", "sachs_model.check_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modifying associated parameterization" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Getting an associated CPD\n", "sachs_model.get_cpds('Akt')\n", "\n", "# Adding new CPDs to the model\n", "sachs_model.add_node('new_node')\n", "new_cpd = TabularCPD('new_node', 2, [[0.2], [0.8]])\n", "sachs_model.add_cpds(new_cpd)\n", "\n", "# Removing the CPD and the node\n", "sachs_model.remove_cpds('new_node')\n", "sachs_model.remove_node('new_node')\n", "\n", "sachs_model.check_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## D-Separation" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n", "True\n", "False\n" ] } ], "source": [ "# Check if two variables in the network are conditionally / unconditionally d-connected.\n", "print(sachs_model.is_dconnected('PKC', 'Akt'))\n", "print(sachs_model.is_dconnected('PKC', 'Akt', observed=['Mek']))\n", "print(sachs_model.is_dconnected('PKC', 'Akt', observed=['Mek', 'PKA']))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'PKA': {'Akt', 'Mek', 'PKC', 'Jnk', 'P38', 'Erk', 'PKA', 'Raf'}}\n", "{'PKA': {'Akt', 'Mek', 'PKC', 'Jnk', 'P38', 'Erk', 'PKA', 'Raf'}, 'Raf': {'Akt', 'Mek', 'Jnk', 'P38', 'Erk', 'PKA', 'PKC', 'Raf'}}\n", "\n", "{'PKA': {'Akt', 'Erk', 'Jnk', 'P38', 'Raf', 'PKA'}}\n", "{'PKA': {'Akt', 'Erk', 'Jnk', 'P38', 'Raf', 'PKA'}, 'Raf': {'Akt', 'Erk', 
'Jnk', 'P38', 'Raf', 'PKA'}}\n" ] } ], "source": [ "# List all the variables that are d-connected to a given variable.\n", "print(sachs_model.active_trail_nodes('PKA'))\n", "print(sachs_model.active_trail_nodes(['PKA', 'Raf']))\n", "\n", "print()\n", "\n", "# List all d-connected variables when conditioned on some other variables\n", "print(sachs_model.active_trail_nodes('PKA', observed=['Mek', 'PKC']))\n", "print(sachs_model.active_trail_nodes(['PKA', 'Raf'], observed=['Mek', 'PKC']))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'Erk', 'PKA'}\n" ] } ], "source": [ "# Find the minimal d-separator of any two variables\n", "print(sachs_model.minimal_dseparator('PKC', 'Akt'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other Methods" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['PKA', 'Mek', 'PKC']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the Markov blanket of a variable\n", "sachs_model.get_markov_blanket('Raf')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Raf ⟂ PIP3, Plcg, Jnk, P38, PIP2 | PKA, PKC)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# List all local independencies of a node\n", "sachs_model.local_independencies('Raf')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(Akt ⟂ PIP3, PIP2, Plcg),\n", " (Akt ⟂ PIP3, PIP2, Plcg | Mek),\n", " (Akt ⟂ PIP3, PIP2, Plcg | Jnk),\n", " (Akt ⟂ PIP3, PIP2, Plcg | P38),\n", " (Akt ⟂ PIP3, PIP2, Plcg | Erk),\n", " (Akt ⟂ Plcg, PIP3 | PIP2),\n", " (Akt ⟂ PIP3, PIP2, Plcg | PKA),\n", " (Akt ⟂ Plcg, PIP2 | PIP3),\n", " (Akt ⟂ PIP3, PIP2 | Plcg),\n", " (Akt ⟂ PIP3, PIP2, Plcg | PKC)]" ] }, "execution_count": 11, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "# List all implied independencies in the network\n", "sachs_model.get_independencies().independencies[:10]" ] } ], "metadata": { "kernelspec": { "display_name": "pgmpy", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 4 }